
import numpy as np #for all numerical computations
import pandas as pd #similar to dataframe in R
import scipy as sp #scientific computations (includes stats)
import matplotlib.pyplot as plt #plotting 
import scipy.io as sio
import cvxpy as cp # convex optimization
import random
import sys
import time
import gc
import math

from sklearn.utils.extmath import randomized_svd

def _spca_sdp_solve(A, s):
    """Solve the SDP relaxation of sparse PCA and return the optimal matrix X.

    Maximizes trace(A X) subject to X being positive semidefinite,
    trace(X) <= 1 and sum(|X_ij|) <= s.

    Raises:
        RuntimeError: if the solver finishes without producing a solution.
    """
    n = A.shape[0]
    X = cp.Variable((n, n), symmetric=True)
    constraints = [
        X >> 0,  # ">>" is cvxpy's matrix (PSD) inequality
        cp.trace(X) <= 1,
        cp.sum(cp.abs(X)) <= s,
    ]
    problem = cp.Problem(cp.Maximize(cp.trace(A @ X)), constraints)
    problem.solve()
    if X.value is None:
        raise RuntimeError(
            f"SPCA SDP solver returned no solution (status: {problem.status})"
        )
    return X.value


def _spca_sdp_best_candidate(X_opt, A, n_draws):
    """Return the column p of X_opt @ G (G standard normal) maximizing p^T A p.

    Draws from NumPy's global RNG, so the caller controls reproducibility by
    seeding it beforehand.
    """
    G = np.random.normal(0, 1, (X_opt.shape[0], n_draws))
    P = X_opt @ G
    # Quadratic form p_i^T A p_i for every column p_i of P in one pass.
    scores = np.einsum("ij,ij->j", P, A @ P)
    # argmax returns the first maximizer, matching "first index where y == max(y)".
    return P[:, np.argmax(scores)].copy()


def _spca_sdp_sample_support(candidate, s):
    """Randomly pick up to ``s`` indices of ``candidate``, weighted by magnitude.

    Sampling (without replacement) is restricted to the ``10 * s`` entries of
    largest absolute value. If fewer than ``s`` of those entries are nonzero,
    all nonzero ones are taken instead of failing. Returns the chosen indices
    in ascending order.

    Raises:
        ValueError: if every candidate entry is zero.
    """
    magnitudes = np.absolute(candidate)
    n_top = min(10 * s, len(candidate))
    top_ids = magnitudes.argsort()[-n_top:][::-1]  # largest magnitude first
    weights = magnitudes[top_ids]

    # np.random.choice(replace=False, p=...) needs at least `size` entries with
    # positive probability; cap the draw instead of letting it raise.
    n_pick = min(s, int(np.count_nonzero(weights)))
    if n_pick == 0:
        raise ValueError(
            "SDP candidate vector is identically zero; cannot select a support"
        )
    chosen = np.random.choice(
        top_ids, n_pick, replace=False, p=weights / weights.sum()
    )
    # Zero-probability entries are never drawn, so every chosen index refers to
    # a nonzero entry; sorting reproduces np.where(...)-style ascending order.
    return np.sort(chosen)


def spca_sdp(data, s, n_draws=300, seed=500):
    """Sparse PCA via an SDP relaxation followed by randomized rounding.

    Steps:
      1. Solve the SDP relaxation with Gram matrix A = data^T data.
      2. Gaussian rounding: among ``n_draws`` vectors X @ g, keep the one with
         the largest quadratic form under A.
      3. Sample a support of at most ``s`` indices, weighted by that vector's
         magnitudes.
      4. Return the leading singular vector of A restricted to that support,
         zero-padded to full length.

    NOTE: like the original implementation, this reseeds NumPy's *global* RNG
    with ``seed``. That also affects any later ``np.random`` calls by the caller.

    Args:
        data: 2-D array-like; the Gram matrix is formed over its columns.
        s: target sparsity (positive integer); also the l1 bound in the SDP.
        n_draws: number of Gaussian rounding samples (default 300).
        seed: seed for NumPy's global RNG before rounding (default 500).

    Returns:
        (z, f): ``z`` is a vector of length ``data.shape[1]`` with at most ``s``
        nonzeros. ``f`` is ``z^T A z / ||A||_2``.

    Raises:
        ValueError: if ``s < 1`` or the rounded candidate vector is all zeros.
        RuntimeError: if the SDP solver returns no solution.
    """
    if s < 1:
        raise ValueError(f"s must be a positive integer, got {s!r}")

    data = np.asarray(data)
    A = data.T @ data
    X_opt = _spca_sdp_solve(A, s)

    np.random.seed(seed)
    candidate = _spca_sdp_best_candidate(X_opt, A, n_draws)
    support = _spca_sdp_sample_support(candidate, s)

    # Leading singular vector of the support-restricted Gram matrix.
    A_red = A[np.ix_(support, support)]
    _, _, VT = randomized_svd(A_red, n_components=1, random_state=10)

    z = np.zeros(len(candidate))
    z[support] = VT[0]  # entries outside the support stay zero
    f = np.dot(z, A @ z) / np.linalg.norm(A, 2)
    print("sdp done")
    return z, f